{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pseudo-Bayesian!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "from jax.scipy.stats import norm, expon\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick check: exponential log-density evaluated at x=3 with loc=0 (scale left at its default).\n",
    "expon.logpdf(3, loc=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as np\n",
    "# True values\n",
    "\n",
    "mu = 1\n",
    "sigma = 3\n",
    "y = 1.1\n",
    "\n",
    "def neg_model_log_prob(params, data):\n",
    "    \"\"\"Return the negative joint log-probability of the model.\n",
    "\n",
    "    The model is ``data ~ Normal(mu, sigma)`` with priors\n",
    "    ``mu ~ Normal(0, 10)`` and ``sigma ~ Exponential`` (default loc/scale).\n",
    "\n",
    "    :param params: Pair ``(mu, log_sigma)``; ``sigma`` is recovered inside\n",
    "        the function as ``exp(log_sigma)``.\n",
    "    :param data: Observation(s) scored under the Normal likelihood; the\n",
    "        per-observation log-densities are summed.\n",
    "    :returns: ``-(log p(mu) + log p(sigma) + sum(log p(data | mu, sigma)))``.\n",
    "    \"\"\"\n",
    "    mu, log_sigma = params\n",
    "    sigma = np.exp(log_sigma)\n",
    "\n",
    "    # log-probability of mu under its prior.\n",
    "    mu_log_prob = norm.logpdf(mu, loc=0, scale=10)\n",
    "\n",
    "    # log-probability of sigma under its prior.\n",
    "    sigma_log_prob = expon.logpdf(sigma)\n",
    "\n",
    "    # log-likelihood of the data given mu and sigma.\n",
    "    likelihood_log_prob = np.sum(norm.logpdf(data, loc=mu, scale=sigma))\n",
    "\n",
    "    # Negated joint log-probability (priors + likelihood), so lower is better.\n",
    "    return -(mu_log_prob + sigma_log_prob + likelihood_log_prob)\n",
    "\n",
    "\n",
    "# The model takes log(sigma), so convert the true sigma before evaluating.\n",
    "neg_model_log_prob((mu, np.log(sigma)), y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import norm as snorm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.autonotebook import tqdm\n",
    "# Simulate 10,000 draws from Normal(loc=1, scale=3). The fixed seed keeps the\n",
    "# results reproducible when the notebook is re-run from a fresh kernel.\n",
    "data = snorm(loc=1, scale=3).rvs(10000, random_state=42)\n",
    "\n",
    "# Objective at mu=1, log_sigma=1 (i.e. sigma = e, not the true sigma of 3).\n",
    "neg_model_log_prob((1, 1), data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax.experimental.optimizers import adam, exponential_decay, inverse_time_decay\n",
    "\n",
    "\n",
    "# Total number of optimisation steps run by the training loop.\n",
    "n_steps = 30000\n",
    "# Step-size schedule built with exponential_decay(0.5, n_steps, 0.01).\n",
    "# NOTE(review): the adam optimizer in this notebook is constructed with a constant\n",
    "# step_size=0.005 rather than this schedule -- confirm that is intended.\n",
    "step_size = exponential_decay(0.5, n_steps, 0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the schedule at step index n_steps to inspect the resulting step size.\n",
    "step_size(n_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mu_start = 10\n",
    "sigma_start = 15\n",
    "\n",
    "from jax import grad, jit\n",
    "from jax.experimental.optimizers import adam, exponential_decay\n",
    "\n",
    "\n",
    "# Gradient of the negative log-probability with respect to its first argument (params).\n",
    "dlogp = grad(neg_model_log_prob)\n",
    "\n",
    "# Adam optimizer triple, built with a constant step size.\n",
    "opt_init, opt_update, get_params = adam(step_size=0.005)\n",
    "\n",
    "# Starting point, in the (mu, log_sigma) order that neg_model_log_prob unpacks.\n",
    "init_params = (10., 0.5)\n",
    "opt_state = opt_init(init_params)\n",
    "\n",
    "\n",
    "@jit\n",
    "def step(i, opt_state, y):\n",
    "    \"\"\"Apply a single optimizer update.\n",
    "\n",
    "    :param i: Step index, forwarded to ``opt_update``.\n",
    "    :param opt_state: Current optimizer state.\n",
    "    :param y: Data passed to the gradient of ``neg_model_log_prob``.\n",
    "    :returns: The optimizer state after the update.\n",
    "    \"\"\"\n",
    "    current_params = get_params(opt_state)\n",
    "    gradients = dlogp(current_params, y)\n",
    "    return opt_update(i, gradients, opt_state)\n",
    "\n",
    "\n",
    "# Run the optimisation, keeping every 10th optimizer state as a checkpoint.\n",
    "states = []\n",
    "for i in tqdm(range(n_steps)):\n",
    "    opt_state = step(i, opt_state, data)\n",
    "    if i % 10 == 0:\n",
    "        states.append(opt_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Objective value at every stored checkpoint (these are negative log-probabilities).\n",
    "logps = [neg_model_log_prob(get_params(state), data) for state in tqdm(states)]\n",
    "\n",
    "# Objective value at the parameters held by the final optimizer state.\n",
    "params = get_params(opt_state)\n",
    "neg_model_log_prob(params, data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "# Plot the checkpointed objective values on a log-scaled y-axis.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(logps)\n",
    "ax.set_yscale(\"log\")\n",
    "ax.set_xlabel(\"Checkpoint (one per 10 optimisation steps)\")\n",
    "ax.set_ylabel(\"Negative joint log-probability\")\n",
    "ax.set_title(\"Optimisation trace\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Objective at the last stored checkpoint (step index 29990 when n_steps is 30000,\n",
    "# since checkpoints are saved only when the step index is a multiple of 10).\n",
    "neg_model_log_prob(get_params(states[-1]), data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack the parameters stored at index -2050 of `states`.\n",
    "# NOTE(review): -2050 is a hard-coded offset from the end of `states`, not the\n",
    "# final checkpoint (states[-1]) -- confirm this index is intended.\n",
    "# NOTE(review): this rebinds the module-level `mu` that an earlier cell sets as\n",
    "# the true value.\n",
    "mu, log_sigma = get_params(states[-2050])\n",
    "# Display mu and sigma, converting log_sigma back with exp.\n",
    "mu, np.exp(log_sigma)\n",
    "# BINGO! WE DID IT!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl-workshop",
   "language": "python",
   "name": "dl-workshop"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
